import os, sys, time, argparse
import math
import random
from easydict import EasyDict as edict
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from pathlib import Path
lib_dir = (Path(__file__).parent / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from datasets import get_datasets, get_nas_search_loaders
from procedures import prepare_seed, prepare_logger
from procedures import Linear_Region_Collector, get_ntk_n, get_ntk_n_zen
from utils import get_model_infos
from log_utils import time_string
from models import get_cell_based_tiny_net, get_search_spaces  # , nas_super_nets
from nas_201_api import NASBench201API as API
from pdb import set_trace as bp
import json


# NAS-Bench-201 benchmark handle (kept under the module-level name `api`).
api = API('/home/chenxinyu/Codes/NAS/ETENAS-master/NAS_data/NAS-Bench-201-v1_1-096897.pth', verbose=False)

# Script configuration (values identical to the original hard-coded ones).
GENOTYPE_FILE = '/home/chenxinyu/Codes/NAS/TENAS-main/genotype.txt'
OUTPUT_FILE = 'output_8.json'
DATASET = 'cifar100'
LOSS_THRESHOLD = 0.8  # an earlier experiment used 1.5
NUM_EPOCHS = 200


def _first_epoch_below(result, threshold=LOSS_THRESHOLD, num_epochs=NUM_EPOCHS):
    """Find the first epoch whose training loss is below ``threshold``.

    Args:
        result: one per-seed trial object from ``api.query_by_index``; only its
            ``get_train(epoch)`` method is used, and the returned mapping must
            contain a ``'loss'`` key.
        threshold: training-loss cutoff (strict ``<`` comparison).
        num_epochs: number of epochs to scan, starting at epoch 0.

    Returns:
        ``(epoch, train_info)`` for the first qualifying epoch. If no epoch
        qualifies, returns the last epoch (``num_epochs - 1``) and its info.
        Exactly one entry is produced per call, so the caller's output list
        stays aligned with the input architectures.
    """
    for epoch in range(num_epochs):
        train_info = result.get_train(epoch)
        if train_info['loss'] < threshold:
            return epoch, train_info
    last_epoch = num_epochs - 1
    return last_epoch, result.get_train(last_epoch)


# For every architecture string in GENOTYPE_FILE, record the training info of
# the first epoch whose loss drops below LOSS_THRESHOLD (or of the final epoch).
info = []
with open(GENOTYPE_FILE, 'r', encoding='utf-8') as genotype_file:
    for line in genotype_file:
        arch = line.strip()
        if not arch:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # querying the benchmark with an empty architecture string.
            continue
        index = api.query_index_by_arch(arch)
        if index < 0:
            # The NAS-Bench-201 API returns -1 for architectures it does not know.
            raise ValueError(f'architecture not found in NAS-Bench-201: {arch!r}')
        results = api.query_by_index(index, DATASET)
        # Only the first seed's trial is used.
        first_result = next(iter(results.values()), None)
        if first_result is None:
            raise ValueError(f'no {DATASET} trials for architecture {arch!r}')
        epoch, train_info = _first_epoch_below(first_result)
        info.append(train_info)
        print(epoch)

# Write the results only after the input file has been fully processed.
with open(OUTPUT_FILE, 'w', encoding='utf-8') as output_file:
    json.dump(info, output_file, indent=4)

# Disabled scratch code, kept as an unused string literal: it would load a
# previous results file and print the 'all_time' field of its first 5 entries.
'''
with open('output_15.json', 'r') as file:
    info_15 = json.load(file)
# info_15 = json.loads('/home/chenxinyu/Codes/NAS/TENAS-main/output_15.json')
for i in range(5):
    print(info_15[i]['all_time'])
'''